{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c1e67966",
   "metadata": {},
   "source": [
    "# Micro-Benchmarking for Transformers\n",
    "\n",
    "This notebook benchmarks the most time-consuming components in BERT, GPT-2 and T5 to help you understand their performance. Let's first check our libraries and hardware. If your GPUs are recent models, make sure your CUDA version is also recent, as this may greatly affect performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "65782c24",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pytorch version\t: 1.13.0a0+08820cb\n",
      "CUDA version\t: 11.7\n",
      "GPU\t\t: NVIDIA GeForce RTX 3090 Ti\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "print('Pytorch version\\t:', torch.__version__)\n",
    "print('CUDA version\\t:', torch.version.cuda)\n",
    "print('GPU\\t\\t:',torch.cuda.get_device_name())"
   ]
  },
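  {
   "cell_type": "markdown",
   "id": "288535a8",
   "metadata": {},
   "source": [
    "Next, define a `walltime` function that benchmarks a PyTorch statement for at least 3 seconds and returns the median run time. The `var_dict` helper collects the variables the statement needs, keyed by name."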
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d06ae2d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "from collections import defaultdict\n",
    "import pandas as pd\n",
    "from torch.utils import benchmark\n",
    "\n",
    "pd.options.display.precision = 3\n",
    "\n",
    "def var_dict(*args):\n",
    "    # Map each argument back to its variable name in the caller's scope.\n",
    "    callers_local_vars = inspect.currentframe().f_back.f_locals.items()\n",
    "    return dict([(name, val) for name, val in callers_local_vars if val is arg][0]\n",
    "                for arg in args)\n",
    "\n",
    "def walltime(stmt, arg_dict, duration=3):\n",
    "    # Median time in seconds per run of `stmt`, measured for at least `duration` seconds.\n",
    "    return benchmark.Timer(stmt=stmt, globals=arg_dict).blocked_autorange(\n",
    "        min_run_time=duration).median"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c98b1141",
   "metadata": {},
   "source": [
    "Lastly, install Hugging Face `transformers` from source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bcd79038",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "\n",
    "!git clone https://github.com/huggingface/transformers\n",
    "!cd transformers; pip install .\n",
    "\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41d00a71",
   "metadata": {},
   "source": [
    "## Matrix Multiplication\n",
    "\n",
    "Matrix multiplication is the most used operator in Transformers. Its performance is crucial. Let's test the [TFLOPS](https://en.wikipedia.org/wiki/FLOPS) we can achieve on square matrices. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d3ca0f77",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>n=128</th>\n",
       "      <th>n=512</th>\n",
       "      <th>n=2048</th>\n",
       "      <th>n=8192</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>torch.float32</th>\n",
       "      <td>0.250</td>\n",
       "      <td>21.784</td>\n",
       "      <td>31.282</td>\n",
       "      <td>42.056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>torch.float16</th>\n",
       "      <td>0.402</td>\n",
       "      <td>25.857</td>\n",
       "      <td>71.505</td>\n",
       "      <td>81.314</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               n=128   n=512  n=2048  n=8192\n",
       "torch.float32  0.250  21.784  31.282  42.056\n",
       "torch.float16  0.402  25.857  71.505  81.314"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "matmul_tflops = defaultdict(lambda: {})\n",
    "for n in [128, 512, 2048, 8192]:\n",
    "    for dtype in (torch.float32, torch.float16):\n",
    "        a = torch.randn(n, n, dtype=dtype).cuda()\n",
    "        b = torch.randn(n, n, dtype=dtype).cuda()\n",
    "        t = walltime('a @ b', var_dict(a, b))\n",
    "        # Multiplying two n x n matrices takes about 2*n**3 floating-point operations.\n",
    "        matmul_tflops[f'n={n}'][dtype] = 2*n**3 / t / 1e12\n",
    "        del a, b\n",
    "        \n",
    "pd.DataFrame(matmul_tflops)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f292a91",
   "metadata": {},
   "source": [
    "You can see that the performance increases with the matrix size. If your GPU has [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/), you will see a big performance jump when switching from 32-bit to 16-bit floating point.\n",
    "\n",
    "Next, you can look up the theoretical TFLOPS of your GPU on Wikipedia, for example, [Nvidia Tesla](https://en.wikipedia.org/wiki/Ampere_(microarchitecture)), [Nvidia Quadro](https://en.wikipedia.org/wiki/Quadro), [RTX 30xx](https://en.wikipedia.org/wiki/GeForce_30_series), and [RTX 20xx](https://en.wikipedia.org/wiki/GeForce_20_series). Here we list several cards, together with their memory information.\n",
    "\n",
    "| Model       | Memory (GB) | Memory Bandwidth (GB/sec) | FP32 TFLOPS | FP16 TFLOPS |\n",
    "| ----------- | ----------- | ------------------------- | ----------- | ----------- |\n",
    "| A100        | 80          | 2039                      | 19.5        | 312         |\n",
    "| V100        | 16          | 900                       | 15.7        | 125         |\n",
    "| A6000       | 48          | 768                       | 38          | 150         |\n",
    "| RTX 3090 Ti | 24          | 1008                      | 40          | 160         |\n",
    "\n",
    "If the best TFLOPS number you get is still far from the theoretical TFLOPS of your GPU, the performance is likely bottlenecked by memory bandwidth. To illustrate this, let's benchmark a simple element-wise multiplication and report both its TFLOPS and its memory bandwidth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6809d73e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>65536</th>\n",
       "      <th>262144</th>\n",
       "      <th>1048576</th>\n",
       "      <th>4194304</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>TFLOPS</th>\n",
       "      <td>0.015</td>\n",
       "      <td>0.060</td>\n",
       "      <td>0.116</td>\n",
       "      <td>0.107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GB/s</th>\n",
       "      <td>120.687</td>\n",
       "      <td>482.782</td>\n",
       "      <td>927.392</td>\n",
       "      <td>859.757</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        65536    262144   1048576  4194304\n",
       "TFLOPS    0.015    0.060    0.116    0.107\n",
       "GB/s    120.687  482.782  927.392  859.757"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vector = defaultdict(lambda: {})\n",
    "for n in [1024*64, 1024*256, 1024*1024, 1024*1024*4]:\n",
    "    a = torch.randn(n).cuda()\n",
    "    t = walltime('a * 1.2', var_dict(a))\n",
    "    vector[n]['TFLOPS'] = n / t / 1e12  # one multiplication per element\n",
    "    vector[n]['GB/s'] = 8 * n / t / 1e9  # 4 bytes read + 4 bytes written per FP32 element\n",
    "    \n",
    "pd.DataFrame(vector)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec285e5f",
   "metadata": {},
   "source": [
    "You can see that even for large vectors, the TFLOPS is far below the GPU's peak performance, while the bandwidth can be quite close to its theoretical number.\n",
    "\n",
    "Matrix multiplication performance is a major topic in HPC, and there is a large body of research on it. Unfortunately, the backend library, cuBLAS, is not open source. For some implementation details, you may check [cutlass](https://github.com/NVIDIA/cutlass), which claims performance similar to cuBLAS.\n"
   ]
  },
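  {
   "cell_type": "markdown",
   "id": "7b3e9a10",
   "metadata": {},
   "source": [
    "As a rough sanity check on the numbers above, here is a minimal sketch of a roofline-style bound: the attainable TFLOPS is at most the smaller of the peak compute rate and the arithmetic intensity (FLOPs per byte moved) times the memory bandwidth. The helper `roofline_tflops` is a hypothetical name, not part of PyTorch. The bandwidth and peak values are the RTX 3090 Ti entries from the table above, and the model ignores caches and kernel launch overhead.\n",
    "\n",
    "```python\n",
    "def roofline_tflops(flops_per_byte, bandwidth_gb_s, peak_tflops):\n",
    "    # Memory-bound limit: FLOPs per byte times bytes per second, in TFLOPS.\n",
    "    memory_bound = flops_per_byte * bandwidth_gb_s * 1e9 / 1e12\n",
    "    return min(peak_tflops, memory_bound)\n",
    "\n",
    "# `a * 1.2` in FP32: 1 FLOP per element, 4 bytes read + 4 bytes written.\n",
    "elementwise_bound = roofline_tflops(1 / 8, 1008, 40)\n",
    "\n",
    "# n x n FP16 matmul: 2*n**3 FLOPs, three matrices of 2-byte elements moved once.\n",
    "n = 8192\n",
    "matmul_bound = roofline_tflops(2 * n**3 / (3 * n * n * 2), 1008, 160)\n",
    "\n",
    "print(f'element-wise bound: {elementwise_bound:.3f} TFLOPS')\n",
    "print(f'matmul bound: {matmul_bound:.1f} TFLOPS')\n",
    "```\n",
    "\n",
    "Under these assumptions the element-wise multiplication is capped at about 0.126 TFLOPS, close to the 0.116 TFLOPS measured above, while the large matrix multiplication is limited by compute rather than by bandwidth."
   ]
  },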
  {
   "cell_type": "markdown",
   "id": "39c13b71",
   "metadata": {},
   "source": [
    "## BERT Layer\n",
    "\n",
    "The main body of a Transformer model is a stack of Transformer blocks. Let's benchmark the performance of a single block. In BERT, it is often called a BERT layer. Let's construct one such layer from the [BERT large model](https://huggingface.co/bert-large-uncased). We use 16-bit floating point for better performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8c9957b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoConfig, BertLayer\n",
    "\n",
    "config = AutoConfig.from_pretrained(\"bert-large-uncased\")\n",
    "layer = BertLayer(config).half().cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b2315ab",
   "metadata": {},
   "source": [
    "Then define a function that benchmarks both the forward pass and the forward plus backward pass for different sequence lengths and batch sizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5f7f89c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_benchmark(layer, hidden_size, seq_lens, batch_sizes, cross_attention=False):\n",
    "    h = hidden_size\n",
    "    results = defaultdict(lambda: {})\n",
    "    encoder_state = 'encoder_hidden_states=X' if cross_attention else ''\n",
    "    for s in seq_lens:\n",
    "        for b in batch_sizes:\n",
    "            ffn = 16*b*s*h*h / 1e12  # FLOPs (in units of 1e12) of the Feed-Forward Network\n",
    "            atten = (4*b*h*s*s + 8*b*s*h*h) / 1e12  # FLOPs (in units of 1e12) of attention\n",
    "            forward = ffn + (2 if cross_attention else 1) * atten\n",
    "            \n",
    "            X = torch.randn(b, s, h).half().cuda()\n",
    "            results[f'batch={b}'][f'fwd seq_len={s}'] = forward / walltime(\n",
    "                f'layer(X, {encoder_state})', var_dict(layer, X))\n",
    "            results[f'batch={b}'][f'fwd+bwd seq_len={s}'] = 3 * forward / walltime(\n",
    "                f'layer(X, {encoder_state})[0].sum().backward()', var_dict(layer, X))            \n",
    "    return pd.DataFrame(results)"
   ]
  },
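  {
   "cell_type": "markdown",
   "id": "e4c2a8f1",
   "metadata": {},
   "source": [
    "The FLOP counts used in `layer_benchmark` can be derived as follows. This is a sketch that assumes a standard Transformer layer with an FFN inner dimension of $4h$ (as in BERT large), counts a multiply-add as 2 FLOPs, and ignores biases, softmax, layer normalization and activation functions. With batch size $b$, sequence length $s$ and hidden size $h$:\n",
    "\n",
    "$$\\text{FFN} = \\underbrace{2 \\cdot bs \\cdot h \\cdot 4h}_{h \\to 4h} + \\underbrace{2 \\cdot bs \\cdot 4h \\cdot h}_{4h \\to h} = 16\\,bsh^2$$\n",
    "\n",
    "$$\\text{Attention} = \\underbrace{4 \\cdot 2\\,bsh^2}_{Q,\\ K,\\ V,\\ \\text{output projections}} + \\underbrace{2\\,bs^2h}_{QK^\\top} + \\underbrace{2\\,bs^2h}_{\\text{scores} \\times V} = 8\\,bsh^2 + 4\\,bs^2h$$\n",
    "\n",
    "The factor of 3 for forward plus backward follows the common rule of thumb that the backward pass costs about twice the forward pass, because it computes gradients with respect to both the inputs and the weights."
   ]
  },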
  {
   "cell_type": "markdown",
   "id": "9116be57",
   "metadata": {},
   "source": [
    "In BERT pre-training, we often train with a sequence length of 128 (stage 1) or 512 (stage 2). Let's test the performance for both."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1e278b06",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>batch=2</th>\n",
       "      <th>batch=4</th>\n",
       "      <th>batch=8</th>\n",
       "      <th>batch=16</th>\n",
       "      <th>batch=32</th>\n",
       "      <th>batch=64</th>\n",
       "      <th>batch=128</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>fwd seq_len=128</th>\n",
       "      <td>15.182</td>\n",
       "      <td>29.216</td>\n",
       "      <td>45.127</td>\n",
       "      <td>45.951</td>\n",
       "      <td>50.962</td>\n",
       "      <td>51.905</td>\n",
       "      <td>56.488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fwd+bwd seq_len=128</th>\n",
       "      <td>14.643</td>\n",
       "      <td>24.224</td>\n",
       "      <td>45.810</td>\n",
       "      <td>50.264</td>\n",
       "      <td>55.816</td>\n",
       "      <td>58.354</td>\n",
       "      <td>61.768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fwd seq_len=512</th>\n",
       "      <td>39.007</td>\n",
       "      <td>40.479</td>\n",
       "      <td>43.847</td>\n",
       "      <td>44.570</td>\n",
       "      <td>47.698</td>\n",
       "      <td>48.522</td>\n",
       "      <td>46.998</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fwd+bwd seq_len=512</th>\n",
       "      <td>40.180</td>\n",
       "      <td>44.317</td>\n",
       "      <td>48.017</td>\n",
       "      <td>49.961</td>\n",
       "      <td>52.374</td>\n",
       "      <td>53.846</td>\n",
       "      <td>52.613</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     batch=2  batch=4  batch=8  batch=16  batch=32  batch=64  \\\n",
       "fwd seq_len=128       15.182   29.216   45.127    45.951    50.962    51.905   \n",
       "fwd+bwd seq_len=128   14.643   24.224   45.810    50.264    55.816    58.354   \n",
       "fwd seq_len=512       39.007   40.479   43.847    44.570    47.698    48.522   \n",
       "fwd+bwd seq_len=512   40.180   44.317   48.017    49.961    52.374    53.846   \n",
       "\n",
       "                     batch=128  \n",
       "fwd seq_len=128         56.488  \n",
       "fwd+bwd seq_len=128     61.768  \n",
       "fwd seq_len=512         46.998  \n",
       "fwd+bwd seq_len=512     52.613  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer_benchmark(layer, config.hidden_size, [128, 512], [2, 4, 8, 16, 32, 64, 128])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "889961fa",
   "metadata": {},
   "source": [
    "Unsurprisingly, a large batch size helps. But the best number is still below the matrix multiplication TFLOPS measured earlier. Let's find out why.\n",
    "\n",
    "We first benchmark the first dense layer of the Feed-Forward Network (FFN) in the layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c39f6f9f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Dense layer TFLOPS: 74.282'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "h, b, s = config.hidden_size, 64, 128\n",
    "X = torch.randn(b, s, h).half().cuda()\n",
    "\n",
    "'Dense layer TFLOPS: %.3f' % (8*b*s*h*h / 1e12 / walltime(    \n",
    "    'layer.intermediate.dense(X)', var_dict(layer, X)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cea6579",
   "metadata": {},
   "source": [
    "The number is pretty good. Next, run this dense layer together with the GELU activation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "44620688",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Dense+Activation TFLOPS: 63.994'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'Dense+Activation TFLOPS: %.3f' % (8*b*s*h*h / 1e12 / walltime(\n",
    "    'layer.intermediate(X)', var_dict(layer, X)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d591ed3c",
   "metadata": {},
   "source": [
    "Even though the activation function has negligible complexity, it brings down the TFLOPS. The reason is the one we saw before: the element-wise operation of the activation function is bound by memory bandwidth.\n",
    "\n",
    "Now test the whole FFN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a6837160",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'FFN TFLOPS: 59.793'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ffn = 16*b*s*h*h / 1e12\n",
    "'FFN TFLOPS: %.3f'%(ffn / walltime(\n",
    "    'layer.output(layer.intermediate(X),X)', var_dict(layer, X)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59214b42",
   "metadata": {},
   "source": [
    "The other part in the BERT layer is the multi-head self-attention."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b2b4e48d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Attention TFLOPS: 41.050'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "att = (4*b*h*s*s + 8*b*s*h*h) / 1e12\n",
    "'Attention TFLOPS: %.3f'%(\n",
    "    att / walltime('layer.attention(X)', var_dict(layer, X)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7eec79b",
   "metadata": {},
   "source": [
    "Even though the main computation in the attention block is still matrix multiplication, it has more memory-bound operators than the FFN, so you see lower TFLOPS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "02d0e4df",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.53125"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "att / ffn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1daaaf4e",
   "metadata": {},
   "source": [
    "The ratio of complexity between attention and FFN depends on the BERT configuration. The overall TFLOPS is a FLOPs-weighted harmonic mean of the TFLOPS of these two components, that is, the total FLOPs divided by the total time."
   ]
  },
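  {
   "cell_type": "markdown",
   "id": "9d5f1b36",
   "metadata": {},
   "source": [
    "To make the weighting concrete, here is a minimal sketch that combines the FFN and attention numbers measured above (59.793 and 41.050 TFLOPS at batch size 64 and sequence length 128), under the assumption that the layer time is simply the sum of the two component times. The function name `combined_tflops` is hypothetical.\n",
    "\n",
    "```python\n",
    "def combined_tflops(flops, tflops):\n",
    "    # Total work divided by total time, where each time is work / throughput.\n",
    "    total_time = sum(f / t for f, t in zip(flops, tflops))\n",
    "    return sum(flops) / total_time\n",
    "\n",
    "# Relative FLOPs: FFN = 1, attention = att / ffn = 0.53125 (computed above).\n",
    "estimate = combined_tflops([1.0, 0.53125], [59.793, 41.050])\n",
    "print(f'Estimated layer TFLOPS: {estimate:.1f}')\n",
    "```\n",
    "\n",
    "The estimate of about 51.6 TFLOPS is close to the 51.905 TFLOPS measured for the forward pass of the whole layer at batch size 64 and sequence length 128."
   ]
  },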
  {
   "cell_type": "markdown",
   "id": "32888ee3",
   "metadata": {},
   "source": [
    "## GPT-2 Block\n",
    "\n",
    "Next let's evaluate `gpt2-medium`, which has a similar architecture to `bert-large`, i.e. 24 layers with a hidden size of 1024. GPT-2 is trained with a sequence length of 1024."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3f889cb3",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>batch=2</th>\n",
       "      <th>batch=4</th>\n",
       "      <th>batch=8</th>\n",
       "      <th>batch=16</th>\n",
       "      <th>batch=32</th>\n",
       "      <th>batch=64</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>fwd seq_len=512</th>\n",
       "      <td>28.082</td>\n",
       "      <td>30.525</td>\n",
       "      <td>33.745</td>\n",
       "      <td>35.026</td>\n",
       "      <td>36.302</td>\n",
       "      <td>36.595</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fwd+bwd seq_len=512</th>\n",
       "      <td>29.920</td>\n",
       "      <td>33.389</td>\n",
       "      <td>35.788</td>\n",
       "      <td>37.137</td>\n",
       "      <td>38.818</td>\n",
       "      <td>39.371</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fwd seq_len=1024</th>\n",
       "      <td>26.884</td>\n",
       "      <td>29.189</td>\n",
       "      <td>30.049</td>\n",
       "      <td>30.911</td>\n",
       "      <td>31.087</td>\n",
       "      <td>30.623</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fwd+bwd seq_len=1024</th>\n",
       "      <td>29.436</td>\n",
       "      <td>31.274</td>\n",
       "      <td>32.150</td>\n",
       "      <td>33.323</td>\n",
       "      <td>33.678</td>\n",
       "      <td>32.888</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      batch=2  batch=4  batch=8  batch=16  batch=32  batch=64\n",
       "fwd seq_len=512        28.082   30.525   33.745    35.026    36.302    36.595\n",
       "fwd+bwd seq_len=512    29.920   33.389   35.788    37.137    38.818    39.371\n",
       "fwd seq_len=1024       26.884   29.189   30.049    30.911    31.087    30.623\n",
       "fwd+bwd seq_len=1024   29.436   31.274   32.150    33.323    33.678    32.888"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers.models.gpt2.modeling_gpt2 import GPT2Block\n",
    "\n",
    "config = AutoConfig.from_pretrained(\"gpt2-medium\")\n",
    "layer = GPT2Block(config, layer_idx=0).half().cuda()\n",
    "layer_benchmark(layer, config.n_embd, [512, 1024], [2, 4, 8, 16, 32, 64])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "974c8cd5",
   "metadata": {},
   "source": [
    "You can see that, although GPT-2 and BERT have the same complexity, GPT-2 has slightly worse TFLOPS when using the same batch size and sequence length. Also, using a larger sequence length of 1024 further hurts the performance."
   ]
  },
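  {
   "cell_type": "markdown",
   "id": "2a6c0e74",
   "metadata": {},
   "source": [
    "One plausible contributor to the drop at longer sequences, offered as an assumption rather than a measured result, is that the attention share of the FLOPs grows with the sequence length, and attention achieved lower TFLOPS than the FFN in the BERT breakdown above. A minimal sketch using the same FLOP formulas as `layer_benchmark` (the function name is hypothetical):\n",
    "\n",
    "```python\n",
    "def attention_to_ffn_ratio(seq_len, hidden_size):\n",
    "    # Per-sample FLOPs, using the formulas from layer_benchmark with b = 1.\n",
    "    s, h = seq_len, hidden_size\n",
    "    attention = 4 * h * s * s + 8 * s * h * h\n",
    "    ffn = 16 * s * h * h\n",
    "    return attention / ffn\n",
    "\n",
    "ratios = {s: attention_to_ffn_ratio(s, 1024) for s in (128, 512, 1024)}\n",
    "print(ratios)\n",
    "```\n",
    "\n",
    "With a hidden size of 1024, the ratio rises from 0.53125 at sequence length 128 to 0.625 at 512 and 0.75 at 1024."
   ]
  },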
  {
   "cell_type": "markdown",
   "id": "1e285d9d",
   "metadata": {},
   "source": [
    "## T5 Layer\n",
    "\n",
    "T5 has both an encoder and a decoder. Let's first benchmark the encoder, whose performance is similar to BERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "74231af6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>batch=2</th>\n",
       "      <th>batch=4</th>\n",
       "      <th>batch=8</th>\n",
       "      <th>batch=16</th>\n",
       "      <th>batch=32</th>\n",
       "      <th>batch=64</th>\n",
       "      <th>batch=128</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>fwd seq_len=512</th>\n",
       "      <td>24.755</td>\n",
       "      <td>28.931</td>\n",
       "      <td>31.319</td>\n",
       "      <td>33.559</td>\n",
       "      <td>35.260</td>\n",
       "      <td>35.832</td>\n",
       "      <td>35.917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fwd+bwd seq_len=512</th>\n",
       "      <td>30.311</td>\n",
       "      <td>34.269</td>\n",
       "      <td>36.650</td>\n",
       "      <td>38.829</td>\n",
       "      <td>40.165</td>\n",
       "      <td>41.210</td>\n",
       "      <td>40.916</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     batch=2  batch=4  batch=8  batch=16  batch=32  batch=64  \\\n",
       "fwd seq_len=512       24.755   28.931   31.319    33.559    35.260    35.832   \n",
       "fwd+bwd seq_len=512   30.311   34.269   36.650    38.829    40.165    41.210   \n",
       "\n",
       "                     batch=128  \n",
       "fwd seq_len=512         35.917  \n",
       "fwd+bwd seq_len=512     40.916  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers.models.t5.modeling_t5 import T5Block\n",
    "\n",
    "config = AutoConfig.from_pretrained(\"t5-large\")\n",
    "config.use_cache = False\n",
    "config.is_decoder = False\n",
    "config.is_encoder_decoder = False\n",
    "\n",
    "encoder = T5Block(config).half().cuda()\n",
    "layer_benchmark(encoder, config.d_model, [512], [2, 4, 8, 16, 32, 64, 128])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19679764",
   "metadata": {},
   "source": [
    "The decoder has an additional cross-attention, which increases the time complexity and also hurts TFLOPS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c9a57c27",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>batch=2</th>\n",
       "      <th>batch=4</th>\n",
       "      <th>batch=8</th>\n",
       "      <th>batch=16</th>\n",
       "      <th>batch=32</th>\n",
       "      <th>batch=64</th>\n",
       "      <th>batch=128</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>fwd seq_len=512</th>\n",
       "      <td>21.37</td>\n",
       "      <td>24.679</td>\n",
       "      <td>26.641</td>\n",
       "      <td>28.612</td>\n",
       "      <td>29.875</td>\n",
       "      <td>30.368</td>\n",
       "      <td>30.349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>fwd+bwd seq_len=512</th>\n",
       "      <td>27.04</td>\n",
       "      <td>30.315</td>\n",
       "      <td>32.287</td>\n",
       "      <td>34.087</td>\n",
       "      <td>35.060</td>\n",
       "      <td>35.986</td>\n",
       "      <td>35.852</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     batch=2  batch=4  batch=8  batch=16  batch=32  batch=64  \\\n",
       "fwd seq_len=512        21.37   24.679   26.641    28.612    29.875    30.368   \n",
       "fwd+bwd seq_len=512    27.04   30.315   32.287    34.087    35.060    35.986   \n",
       "\n",
       "                     batch=128  \n",
       "fwd seq_len=512         30.349  \n",
       "fwd+bwd seq_len=512     35.852  "
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config.is_decoder = True\n",
    "decoder = T5Block(config).half().cuda()\n",
    "layer_benchmark(decoder, config.d_model, [512], [2, 4, 8, 16, 32, 64, 128], cross_attention=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d1a2765",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "To conclude, to achieve the best performance for a Transformer layer, you need to use a fast data type and a large batch size. For further improvement, we may need to rewrite the code, for example by [fusing](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#fuse-pointwise-operations) multiple kernels into a single one."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
